import torchvision
from torchvision import transforms
from torchvision.datasets import MNIST

# Build the MNIST training split rooted at ./data; download=True lets
# torchvision fetch the files there when they are missing.
mnist = MNIST(root="./data", train=True, download=True)

# Inspect the first dataset item, then turn its first element (the image,
# judging by the commented-out .show() call) into a tensor and report its size.
first_item = mnist[0]
print(first_item)
# first_item[0].show()
to_tensor = transforms.ToTensor()
ret = to_tensor(first_item[0])
print(ret.size())








# k = 0
# for i in mnist:
#     i[0].show()
#     k+=1
#     if k == 10:
#         break


